import os
import sys
from datetime import datetime
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
import torchvision.transforms as transforms
from models.ResNet_LSTM import Res_LSTM
from dataset import CSL_Isolated
from train import train_epoch
from validation import val_epoch

# NOTE: the reload(sys)/setdefaultencoding hack below was a Python 2
# workaround (str defaulted to ASCII and clashed with unicode); it is
# unnecessary and unavailable on Python 3.
# reload(sys)
# sys.setdefaultencoding('utf8')

# Path setting
# data_path = "D:\Works\Hisi_codes\SLRDataset_isolate\color"
# label_path = "D:\Works\Hisi_codes\SLRDataset_isolate\gloss_label_2_utf8.txt"
# model_path = "D:\Works\Hisi_codes\SLR-master\\reslstm_models"
data_path = "/home/hs18352379891/datasets/SLR_Dataset_images"
# data_path = "/home/hs18352379891/datasets/SLR_Dataset_videos"
label_path = "/home/hs18352379891/datasets/labels_101_utf8.txt"
model_path = "/home/hs18352379891/SLR-master-server/res50lstm_models"
log_path = "log/res50lstm_{:%Y-%m-%d_%H-%M-%S}.log".format(datetime.now())
sum_path = "runs/res50lstm_{:%Y-%m-%d_%H-%M-%S}".format(datetime.now())

wights_path = "/home/hs18352379891/SLR-master-server/res50lstm_models"  # directory holding pretrained weight checkpoints

# Log to file & tensorboard writer.
# logging.FileHandler raises FileNotFoundError when the "log" directory is
# missing, so create it up front (SummaryWriter creates sum_path itself).
os.makedirs(os.path.dirname(log_path), exist_ok=True)
logging.basicConfig(level=logging.INFO, format='%(message)s', handlers=[logging.FileHandler(log_path), logging.StreamHandler()])
logger = logging.getLogger('SLR')
logger.info('Logging to file...')
writer = SummaryWriter(sum_path)


# GPU selection: restrict CUDA to device 1 (the GPU id must also be
# specified in the terminal when launching the training program).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Device setting: prefer GPU, fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device: ", device)

# Hyperparams
epochs = 100
batch_size = 128  # official server — try a larger batch
learning_rate = 1e-4
weight_decay = 1e-5
log_interval = 20
sample_size = 128
sample_duration = 8  # number of frames fed in per sample
num_classes = 101  # train on only 101 words for now
lstm_hidden_size = 512
lstm_num_layers = 1
# attention = False
pretrain = False  # whether to load pretrained weights
# ResNet backbone selection
# arch = "resnet18"
# arch = "resnet34"
arch = "resnet50"

# Train with Conv+LSTM
if __name__ == '__main__':
    # Load data: resize frames to sample_size x sample_size, convert to
    # tensors, and normalize channels to roughly [-1, 1].
    transform = transforms.Compose([transforms.Resize([sample_size, sample_size]),
                                    transforms.ToTensor(),
                                    transforms.Normalize(mean=[0.5], std=[0.5])])
    train_set = CSL_Isolated(data_path=data_path, label_path=label_path, frames=sample_duration,
        num_classes=num_classes, train=True, transform=transform)
    val_set = CSL_Isolated(data_path=data_path, label_path=label_path, frames=sample_duration,
        num_classes=num_classes, train=False, transform=transform)
    logger.info("Dataset samples: {}".format(len(train_set)+len(val_set)))
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
    print("train_loader finished")
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
    print("val_loader finished")

    # Create model
    print("Creating model...", end="")
    model = Res_LSTM(sample_size=sample_size, sample_duration=sample_duration, num_classes=num_classes, arch=arch).to(device)
    print("\tCreate model finished")

    if pretrain:  # resume from previously saved weights
        weight_list = sorted(os.listdir(wights_path))  # every checkpoint in the weights directory
        if not weight_list:
            raise FileNotFoundError("pretrain=True but no weight files found in {}".format(wights_path))
        # Filenames embed a zero-padded epoch number, so lexicographic order
        # matches epoch order and the last entry is the newest checkpoint.
        latest_weight = os.path.join(wights_path, weight_list[-1])
        print("Load weights from : ", end="")
        print(latest_weight)
        # map_location keeps loading working when the checkpoint was saved on
        # a different device (e.g. a GPU checkpoint loaded on CPU).
        model.load_state_dict(torch.load(latest_weight, map_location=device))

    # Create loss criterion & optimizer
    print("Create loss criterion & optimizer")
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Ensure the checkpoint directory exists before the first torch.save.
    os.makedirs(model_path, exist_ok=True)

    # Start training
    logger.info("Training Started".center(60, '#'))
    for epoch in range(epochs):
        # Train the model
        train_epoch(model, criterion, optimizer, train_loader, device, epoch, logger, log_interval, writer)

        # Validate the model
        val_epoch(model, criterion, val_loader, device, epoch, logger, writer)

        # Save a checkpoint after every epoch (epoch numbers are 1-based in filenames)
        torch.save(model.state_dict(), os.path.join(model_path, "res50lstm_epoch{:03d}.pth".format(epoch+1)))
        logger.info("Epoch {} Model Saved".format(epoch+1).center(60, '#'))

    logger.info("Training Finished".center(60, '#'))